import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler


def get_weighted_sampler(dataset, *, num_samples=None, replacement=True):
    """Build a sampler that handles class imbalance via inverse-frequency weights.

    Each sample is weighted by ``1 / count(its label)``. Every class that
    occurs in ``dataset.labels`` therefore carries the same total weight and
    is drawn about equally often in expectation.

    Args:
        dataset: Object exposing ``labels`` (a 1-D sequence of non-negative
            integer class ids such as a list, NumPy array or CPU tensor) and
            ``__len__``.
        num_samples: Number of indices the sampler yields per pass.
            Defaults to ``len(dataset)``.
        replacement: Whether indices are drawn with replacement.
            Defaults to ``True``.

    Returns:
        A ``WeightedRandomSampler`` over the dataset's indices.

    Raises:
        ValueError: If ``dataset.labels`` is empty. ``np.bincount`` also
            raises if labels are negative or not 1-D.
    """
    labels = np.asarray(dataset.labels)
    if labels.size == 0:
        raise ValueError("cannot build a weighted sampler for an empty label set")

    class_counts = np.bincount(labels)
    # Look up each sample's class count first, then invert. Every label that
    # is present has a count >= 1, so class ids that never occur (zero bins
    # from bincount) are never divided by. This gives no divide-by-zero
    # warning and no inf, with the same weights as inverting the counts.
    weights = 1.0 / class_counts[labels]

    if num_samples is None:
        num_samples = len(dataset)

    return WeightedRandomSampler(
        weights,
        num_samples=num_samples,
        replacement=replacement,
    )